import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier

import argparse
import time
import warnings

import sys
sys.path.append('../code')
from classification import *

from pathlib import Path
import pickle

def parse_commandline():
    """Build the command-line parser and return the parsed namespace.

    Returns
    -------
    argparse.Namespace
        Carries a single attribute ``params``: the list of strings given
        after ``--params`` (one or more), or ``None`` when the flag is absent.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--params", nargs='+', default=None)
    return cli.parse_args()

###############################################################################
# BEGIN MAIN FUNCTION
###############################################################################
if __name__ == '__main__':
    warnings.simplefilter('ignore')

    # The seed comes from the first --params value. It seeds NumPy's global
    # RNG, which drives the bootstrap resampling below, and is also passed
    # to classification_onefold.
    args = parse_commandline()
    if not args.params:
        sys.exit('error: the random seed must be given via --params SEED')
    seed = int(args.params[0])
    np.random.seed(seed)

    # Number of bootstrap resamples of the test sets.
    n_bs = 10000

    # Whether to balance the train-set size between the groups:
    #   1.0  -> perfectly balanced
    #   None -> no balancing
    ratio = None
    test_size = 0.33

    # Load the data. Every column except the outcome columns and the group
    # indicator is used as a feature, sorted for a deterministic order.
    df = pd.read_csv('../data/df_health_risk.csv')
    Y_column = 'gagne_sum_t'
    G_column = 'dem_race_black'
    Y_columns = ['cost_t', 'gagne_sum_t']
    X_column = sorted(set(df.columns) - set(Y_columns) - {G_column})

    # Binarize the outcome in place: 1 for rows strictly above the
    # thres_percentile-th percentile of Y_column, else 0. The threshold and
    # the comparison both read Y_column so they cannot drift apart.
    thres_percentile = 97
    y_thres = np.percentile(df[Y_column].values, thres_percentile)
    df[Y_column] = (df[Y_column].values > y_thres) * 1

    clf_b = LogisticRegression(max_iter=1000)
    clf_r = LogisticRegression(max_iter=1000)

    res, X_test_blue, y_test_blue, X_test_red, y_test_red = \
        classification_onefold(df, Y_column, G_column, X_column, clf_b, clf_r,
                               seed=seed, ratio=ratio, verbose=True,
                               test_size=test_size, standardize=False)

    # Point estimates on the original (non-resampled) test sets, taken from
    # the 'res_blue' entry returned by classification_onefold.
    clf_b = res['clf_b']
    e_b_B = res['res_blue']['e_b']
    e_r_B = res['res_blue']['e_r']

    # Resampling uses positional indices, so give the feature frames a
    # 0..n-1 index and make the labels plain arrays. y[idx] is then always
    # positional, even if labels come back as a pandas Series with the
    # original row index. For ndarrays, np.asarray is a no-op.
    X_test_blue = X_test_blue.reset_index(drop=True)
    X_test_red = X_test_red.reset_index(drop=True)
    y_test_blue = np.asarray(y_test_blue)
    y_test_red = np.asarray(y_test_red)

    n_blue = X_test_blue.shape[0]
    n_red = X_test_red.shape[0]
    e_b_B_list = []
    e_r_B_list = []
    for _ in range(n_bs):
        # Resample each group's test set independently, with replacement.
        # Draw order (blue, then red) is fixed so results are reproducible
        # for a given seed.
        idx_blue = np.random.choice(n_blue, n_blue, replace=True)
        idx_red = np.random.choice(n_red, n_red, replace=True)

        res_blue = est_group_loss(clf_b,
                                  X_test_blue.iloc[idx_blue], y_test_blue[idx_blue],
                                  X_test_red.iloc[idx_red], y_test_red[idx_red])
        e_b_B_list.append(res_blue['e_b'])
        e_r_B_list.append(res_blue['e_r'])
    e_b_B_list = np.array(e_b_B_list)
    e_r_B_list = np.array(e_r_B_list)

    # One-sided bootstrap p-value for the difference r = e_b - e_r.
    # Subtracting the observed r_hat re-centres the bootstrap distribution
    # at zero to approximate the null. p_b is the share of re-centred draws
    # that exceed r_hat.
    r_skewed_list = e_b_B_list - e_r_B_list
    r_skewed = e_b_B - e_r_B
    p_b = np.mean(r_skewed_list - r_skewed > r_skewed)

    # Save the results. Create the output directory first so a missing
    # directory does not fail only after the bootstrap has finished.
    out_dir = Path('/home/kop5674/FA_frontier/result/science/')
    out_dir.mkdir(parents=True, exist_ok=True)
    filename = f'test_science_logreg_{seed}_{ratio}.pkl'
    # NOTE(review): the keys say '_R' while the values are the e_*_B
    # estimates. They are left unchanged because other code may load these
    # pickles; confirm before renaming.
    res = {'p_b': p_b, 'e_b_R': e_b_B, 'e_r_R': e_r_B}

    print(res)

    with open(out_dir / filename, 'wb') as f:
        pickle.dump(res, f)